#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time
import sys  
import os

# Hackish way to reuse the existing start-up code from the macro benchmark.
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + 
        "/../tests/macro_benchmark/")
from databases import *
from clients import *

def avg_no_label_node_size(memgraph, client, measure_points):
    """Estimate the average memory cost, in bytes, of a node without labels.

    For every node count in ``measure_points`` the database is started, its
    memory usage is read before and after creating that many label-less
    nodes, and the database is stopped again. The per-node costs of all
    runs are then averaged.

    Args:
        memgraph: Database runner providing ``start()``, ``stop()`` and
            ``database_bin.get_usage()``.
        client: Callable invoked as ``client(queries, memgraph)`` with a
            list of query strings to execute.
        measure_points: Sized iterable of node counts to measure.

    Returns:
        Mean bytes per node over all measure points.

    Raises:
        ZeroDivisionError: If ``measure_points`` is empty or contains 0.
    """
    NODE_CREATION = "UNWIND RANGE(1, {}) AS _ CREATE ()"
    node_size_sum = 0

    for node_cnt in measure_points:
        memgraph.start()
        try:
            idle_memory = memgraph.database_bin.get_usage()['memory']
            client([NODE_CREATION.format(node_cnt)], memgraph)
            current_memory = memgraph.database_bin.get_usage()['memory']
        finally:
            # Always call stop(), even when a query or a usage reading
            # fails, so a started database is never left behind.
            memgraph.stop()

        # The reading is scaled by 1024 to obtain bytes (presumably it is
        # reported in KiB -- confirm against get_usage()).
        node_size_sum += (current_memory - idle_memory) * 1024. / node_cnt

    return node_size_sum / len(measure_points)

def avg_label_on_node_size(memgraph, client, measure_points):
    """Estimate the average extra memory, in bytes, one label adds to a node.

    For every node count in ``measure_points`` and every label count from 1
    to ``LABEL_SIZE``, two separate database sessions are run: one that
    creates the nodes without labels and one that creates the same number
    of nodes carrying the labels ``:L0``, ``:L1``, ... The difference of the
    two memory readings, divided by the node count and the label count, is
    the cost of a single label on a single node; all such costs are
    averaged.

    Args:
        memgraph: Database runner providing ``start()``, ``stop()`` and
            ``database_bin.get_usage()``.
        client: Callable invoked as ``client(queries, memgraph)`` with a
            list of query strings to execute.
        measure_points: Sized iterable of node counts to measure.

    Returns:
        Mean bytes per label per node over all runs.

    Raises:
        ZeroDivisionError: If ``measure_points`` is empty or contains 0.
    """
    NODE_CREATION = "UNWIND RANGE(1, {}) AS _ CREATE ()"
    NODE_WITH_LABEL_CREATION = "UNWIND RANGE(1, {}) AS _ CREATE (n{})"
    LABEL_SIZE = 5
    label_on_node_size_sum = 0

    for node_cnt in measure_points:
        for label_count in range(1, LABEL_SIZE + 1):
            # Baseline session: memory usage with unlabeled nodes only.
            memgraph.start()
            try:
                client([NODE_CREATION.format(node_cnt)], memgraph)
                unlabeled_memory = memgraph.database_bin.get_usage()['memory']
            finally:
                # stop() must run even if the query or the reading fails.
                memgraph.stop()

            labels = "".join(":L" + str(i) for i in range(label_count))

            # Second session: same node count, each node with the labels.
            memgraph.start()
            try:
                client([NODE_WITH_LABEL_CREATION.format(node_cnt, labels)],
                       memgraph)
                labeled_memory = memgraph.database_bin.get_usage()['memory']
            finally:
                memgraph.stop()

            # The reading is scaled by 1024 to obtain bytes (presumably it
            # is reported in KiB -- confirm against get_usage()).
            label_on_node_size_sum += ((labeled_memory - unlabeled_memory)
                                       * 1024. / node_cnt / label_count)

    return label_on_node_size_sum / (len(measure_points) * LABEL_SIZE)

def avg_edge_size(memgraph, client, measure_points):
    """Estimate the average memory cost, in bytes, of a single edge.

    For every edge count in ``measure_points`` the database is started, its
    memory usage is read before and after creating one node with that many
    self-referencing ``:Edge`` relationships, and the database is stopped
    again. The per-edge costs of all runs are then averaged. The single
    node created by the query is included in the measured difference.

    Args:
        memgraph: Database runner providing ``start()``, ``stop()`` and
            ``database_bin.get_usage()``.
        client: Callable invoked as ``client(queries, memgraph)`` with a
            list of query strings to execute.
        measure_points: Sized iterable of edge counts to measure.

    Returns:
        Mean bytes per edge over all measure points.

    Raises:
        ZeroDivisionError: If ``measure_points`` is empty or contains 0.
    """
    EDGE_CREATION = "CREATE (a) WITH a UNWIND range(1, {}) AS _ CREATE (a)-[:Edge]->(a)"
    edge_size_sum = 0

    for edge_cnt in measure_points:
        memgraph.start()
        try:
            idle_memory = memgraph.database_bin.get_usage()['memory']
            client([EDGE_CREATION.format(edge_cnt)], memgraph)
            current_memory = memgraph.database_bin.get_usage()['memory']
        finally:
            # Always call stop(), even when a query or a usage reading
            # fails, so a started database is never left behind.
            memgraph.stop()

        # The reading is scaled by 1024 to obtain bytes (presumably it is
        # reported in KiB -- confirm against get_usage()).
        edge_size_sum += (current_memory - idle_memory) * 1024. / edge_cnt

    return edge_size_sum / len(measure_points)

def main():
    """Run all three memory measurements and print the estimated sizes.

    A ``Memgraph`` runner and a ``QueryClient`` are constructed with the
    arguments ``(None, 1)`` (their meaning is defined by those classes --
    see the macro benchmark modules) and shared by every measurement. All
    measurements complete before any result is printed.
    """
    database = Memgraph(None, 1)
    query_runner = QueryClient(None, 1)

    # 1, 2, 5 and 10 million elements per measurement run.
    sample_sizes = [scale * 10**6 for scale in (1, 2, 5, 10)]

    report = [
        ("Expected node without label size (Bytes): ",
         avg_no_label_node_size(database, query_runner, sample_sizes)),
        ("Expected label addition to node (Bytes): ",
         avg_label_on_node_size(database, query_runner, sample_sizes)),
        ("Expected edge size (Bytes): ",
         avg_edge_size(database, query_runner, sample_sizes)),
    ]

    for caption, measured in report:
        print(caption, measured)

# Run the measurements only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
